import math
from typing import Callable, Any, Tuple

from centralized_verification.envs.continuous_grid_world import ContinuousGridWorld
from centralized_verification.envs.fast_grid_world import FastGridWorld
from centralized_verification.envs.particle_momentum import ParticleMomentum

# Signature of a label extractor: takes an environment state and returns a
# tuple label. The extractors in this module return two-element tuples, so the
# return type is a variable-length tuple (``Tuple[int]`` would mean a tuple of
# exactly one int).
LabelExtractor = Callable[[Any], Tuple[int, ...]]


def FullObsGridWorldLabelExtractor2Agents(
    env: FastGridWorld, *, x_offset: int = 22, y_offset: int = 9
):
    """Build a label extractor for a fully observed two-agent ``FastGridWorld``.

    The returned callable takes a state ``(a1, a2)`` whose entries index
    ``env.grid_posns``. It looks up each agent's ``(x, y)`` grid position and
    returns the displacement of agent 1 relative to agent 2, shifted by
    constant offsets: ``(x_offset + a1x - a2x, y_offset + a1y - a2y)``.

    Args:
        env: Environment whose ``grid_posns`` maps each agent's state entry to
            an ``(x, y)`` pair. It is read on every call of the extractor.
        x_offset: Constant added to the x displacement (default 22).
        y_offset: Constant added to the y displacement (default 9).

    Returns:
        A callable mapping a two-agent state to a 2-tuple label.
    """
    # NOTE(review): the default offsets presumably shift the displacement into
    # a non-negative range for the default grid size. Confirm against
    # FastGridWorld before changing them.

    def label_extractor(state_num):
        a1, a2 = state_num
        (a1x, a1y), (a2x, a2y) = env.grid_posns[a1], env.grid_posns[a2]

        return x_offset + a1x - a2x, y_offset + a1y - a2y

    return label_extractor


def ParticleMomentumLabelExtractor2Agents(env: ParticleMomentum):
    """Build a label extractor for the two-agent ``ParticleMomentum`` env.

    The returned callable expects a state with exactly four entries. It labels
    the state with its first two entries, returned as a 2-tuple.

    ``env`` is not used by this extractor. It is presumably accepted so that
    all extractor factories share one signature.
    """

    def label_extractor(state):
        # Unpack all four entries so that a state of any other length still
        # raises ValueError. Only the leading pair forms the label.
        first, second, _ignored_third, _ignored_fourth = state
        label = (first, second)
        return label

    return label_extractor


def ContinuousGridWorldLabelExtractor(
    env: ContinuousGridWorld, *, x_offset: int = 22, y_offset: int = 9
):
    """Build a label extractor for the two-agent ``ContinuousGridWorld``.

    The returned callable takes a state ``(a1x, a1y, a2x, a2y)`` of continuous
    coordinates. It floors each coordinate to an integer cell and returns the
    cell displacement of agent 1 relative to agent 2, shifted by constant
    offsets:
    ``(x_offset + floor(a1x) - floor(a2x), y_offset + floor(a1y) - floor(a2y))``.

    Args:
        env: The environment (not consulted by the extractor).
        x_offset: Constant added to the x displacement (default 22).
        y_offset: Constant added to the y displacement (default 9).

    Returns:
        A callable mapping a four-entry state to a 2-tuple of ints.

    The extractor raises ValueError (from unpacking) if the state does not
    have exactly four entries.
    """
    # NOTE(review): the defaults match the discrete grid-world extractor. They
    # presumably keep labels non-negative for the default grid size; confirm.

    def label_extractor(state):
        a1x, a1y, a2x, a2y = state

        # Floor each coordinate separately (not the difference) so that the
        # label depends only on the cells the agents occupy.
        dx = math.floor(a1x) - math.floor(a2x)
        dy = math.floor(a1y) - math.floor(a2y)
        return x_offset + dx, y_offset + dy

    return label_extractor
